import os
import pickle
import random

import numpy as np
from tqdm import trange

from .base_env import BaseEnvMOMAB
from utils import initial_thetalist, generate_diverse_theta, uniform_cxts, pareto_front, pareto_front_v2

class MO_Context_Linear_env(BaseEnvMOMAB):
    """Multi-objective contextual linear bandit environment replayed from a cached dataset.

    Each stored round holds the contexts for ``K`` arms, drawn with
    ``uniform_cxts(K, d, seed=...)``, and the expected-reward matrix
    ``contexts @ theta_list.T`` (one column per objective). It also holds the
    outputs of ``pareto_front`` / ``pareto_front_v2`` for that matrix. Rounds and
    ``theta_list`` are pickled under ``./data`` so later runs with the same
    configuration replay the same environment.

    Args:
        K: number of arms per round (forwarded to ``BaseEnvMOMAB``).
        num_obj: number of objectives; ``theta_list`` is presumably
            (num_obj, d) -- TODO confirm against ``generate_diverse_theta``.
        d: context dimension.
        sig: standard deviation of the Gaussian noise added by ``action_reward``.
        version: tag embedded in the dataset file name.
        fixed_context: if True, ``update_env`` never leaves the current round.
    """

    # Largest seed (equal to the round's index in data_eval) that save_envs
    # will generate; generation stops once the next seed would exceed it.
    MAX_SEED = 4000

    def __init__(self, K, num_obj=3, d=10, sig=0.1, version=0, fixed_context=False):
        super().__init__(K=K)
        self.fixed_context = fixed_context
        self.num_obj = num_obj
        self.sig = sig
        self.d = d
        self.version = version
        self.t = 0  # index of the current round in data_eval
        self.load_envs()

    def _dataset_path(self):
        """Return the pickle path for this (K, d, num_obj, version) configuration."""
        return './data/dataset_K%d_d%d_L%d_v%d.pkl' % (self.K, self.d, self.num_obj, self.version)

    def save_envs(self, time):
        """Generate up to ``time`` new rounds, append them, and rewrite the dataset file.

        The round stored at position ``j`` of ``data_eval`` uses seed ``j``.
        Repeated calls therefore continue the seed sequence instead of
        duplicating rounds already present. Generation stops early, with a
        message, once the next seed would exceed ``MAX_SEED``. Afterwards
        ``self.data_size`` matches ``len(self.data_eval)``.
        """
        for _ in trange(time):
            # Derive the seed from the live list length rather than a cached
            # counter, so a second call on the same instance does not reuse seeds.
            seed = len(self.data_eval)
            if seed > self.MAX_SEED:
                print("data_size over limit!")
                break
            contexts = uniform_cxts(self.K, self.d, seed=seed)
            exp_rewards = np.matmul(contexts, self.theta_list.T)

            pareto_idx_old, pareto_regret_old = pareto_front(exp_rewards)
            pareto_idx_ours, pareto_regret_ours = pareto_front_v2(exp_rewards, pareto_idx_old)
            self.data_eval.append({'contexts': contexts.tolist(),
                                   'exp_rewards': exp_rewards.tolist(),
                                   'pareto_idx_old': pareto_idx_old,
                                   'pareto_regret_old': pareto_regret_old,
                                   'pareto_idx_ours': pareto_idx_ours,
                                   'pareto_regret_ours': pareto_regret_ours})

        self.data_size = len(self.data_eval)

        data = {'theta_list': self.theta_list.tolist(),
                'eval': self.data_eval}

        path = self._dataset_path()
        # Create ./data up front; otherwise a fresh checkout crashes here only
        # after all the (costly) Pareto computations above have finished.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def load_envs(self):
        """Load ``theta_list`` and the evaluation rounds, generating them if needed.

        If the dataset file cannot be read, a fresh ``theta_list`` is drawn with
        ``generate_diverse_theta``. If no rounds are available after that, 100
        rounds are generated and written to disk by ``save_envs``.
        """
        try:
            # NOTE: pickle is only safe here because this class writes the file
            # itself; never point this at data from an untrusted source.
            with open(self._dataset_path(), 'rb') as f:
                data = pickle.load(f)
            self.theta_list = np.array(data['theta_list'])
            self.data_eval = data['eval']
            print("data is fully loaded")
        except Exception as err:
            # Deliberately not a bare ``except``: KeyboardInterrupt/SystemExit
            # during a long load must propagate instead of triggering a
            # regeneration that overwrites the dataset file.
            if not isinstance(err, FileNotFoundError):
                print("could not read dataset (%s: %s); regenerating" % (type(err).__name__, err))
            self.theta_list = generate_diverse_theta(self.d, self.num_obj)
            self.data_eval = []

        self.data_size = len(self.data_eval)
        print("data_length : ", self.data_size)
        if self.data_size == 0:
            print("data is not loaded, making environment")
            # save_envs leaves theta_list, data_eval and data_size up to date in
            # memory, so there is no need to re-read the file (the old recursive
            # reload could loop forever if the file could not be read back).
            self.save_envs(100)
            print("data_length : ", self.data_size)

    def _load_round(self):
        """Make round ``self.t`` of ``data_eval`` the current environment state."""
        entry = self.data_eval[self.t]
        self.contexts = np.array(entry['contexts'])
        self.exp_rewards = np.array(entry['exp_rewards'])
        self.pareto_idx_old = entry['pareto_idx_old']
        self.pareto_regret_old = entry['pareto_regret_old']
        self.pareto_idx_ours = entry['pareto_idx_ours']
        self.pareto_regret_ours = entry['pareto_regret_ours']

    def warm_up(self):
        """Load round ``self.t`` (0 right after construction) as the current state."""
        self._load_round()

    def view_context(self):
        """Return the contexts of the current round as a NumPy array."""
        return self.contexts

    def action_reward(self, idx):
        """Return arm ``idx``'s expected rewards plus N(0, sig^2) noise on each objective."""
        return self.exp_rewards[idx] + np.random.normal(0, self.sig, size=self.num_obj)

    def mean_reward(self, idx):
        """Return arm ``idx``'s noiseless expected reward vector for the current round."""
        return self.exp_rewards[idx]

    def update_env(self):
        """Move to the next stored round.

        After the last round, jump to a uniformly random round. This does
        nothing when ``fixed_context`` is True. Always returns None.
        """
        if self.fixed_context:
            return None
        if self.t < self.data_size - 1:
            self.t += 1
        else:
            self.t = random.randint(0, self.data_size - 1)
        self._load_round()
        return None
            
    



        
        
